# Table 2 in "Evaluating (weighted) dynamic treatment effects by double machine learning" 
# by Hugo Bodory, Martin Huber, and Lukáš Lafférs

# To replicate the results for the 12 DGPs analyzed in this simulation study, adapt the
# variables numsim, n, p0, and s (set at the start of the simulation section) accordingly.

rm(list = ls())
# all packages used in the simulation study
packages = c('SuperLearner','e1071','glmnet','ranger','xgboost','tictoc','parallel','writexl','ggplot2','dplyr','mvtnorm','xtable')
lapply(packages,FUN=function(packages) {do.call("require", list(packages))}) 

# SuperLearner algorithms
# Fit a SuperLearner with a single learner on a PSOCK cluster.
#   y         outcome vector
#   x         covariates (matrix or data.frame, one row per element of y)
#   d1, d2    optional 0/1 vectors; if d1 is given, only rows with d1 == 1
#             (and d2 == 1, if d2 is given as well) are used for fitting
#   MLmethod  "lasso" (SL.glmnet) or "randomforest" (SL.ranger)
#   ybin      1 for a binary outcome (binomial family), otherwise gaussian
#   ncores    number of PSOCK worker processes
#   iseed     seed of the parallel RNG streams on the workers
# Returns the fitted model from snowSuperLearner, or NaN if the fit raised an
# error (a warning carrying the error message is emitted in that case).
MLfunct <- function(y, x, d1 = NULL, d2 = NULL, MLmethod = "lasso", ybin = 0,
                    ncores = 16, iseed = 2343) {
  # validate before any worker process is started
  MLmethod <- match.arg(MLmethod, c("lasso", "randomforest"))
  if (!is.null(d1)) {
    keep <- if (is.null(d2)) d1 == 1 else d1 == 1 & d2 == 1
    y <- y[keep]
    x <- x[keep, , drop = FALSE]  # stay two-dimensional with a single covariate
  }
  learner <- if (MLmethod == "randomforest") "SL.ranger" else "SL.glmnet"
  fam <- if (ybin == 1) binomial() else gaussian()
  cl <- makeCluster(ncores, type = "PSOCK")
  on.exit(stopCluster(cl), add = TRUE)  # shut the workers down even on error
  clusterSetRNGStream(cl, iseed = iseed)
  # make SuperLearner functions available on the workers
  clusterEvalQ(cl, library(SuperLearner))
  tryCatch(
    snowSuperLearner(cluster = cl, Y = y, X = x, family = fam,
                     SL.library = learner, verbose = TRUE),
    error = function(e) {
      # NaN is the failure sentinel checked by the callers
      warning("SuperLearner fit failed: ", conditionMessage(e), call. = FALSE)
      NaN
    }
  )
}

# dynamic treatment effect estimation
# DML estimate of the average effect of the treatment sequence
# (d1treat, d2treat) versus (d1control, d2control) on y2, using cross-fitted
# nuisance scores from hddyntreat.
#   d1treat, d2treat, d1control, d2control: a scalar treatment value, or an
#     already constructed 0/1 indicator vector (used as is)
#   s: optional weighting variable (NULL = unweighted)
#   trim: observations with p1 * p2 < trim in either score are dropped
#   repl: replication number, passed through to hddyntreat
# Returns a list with effect, se, pval, ntrimmed, meantreat, meancontrol, the
# propensity scores of the untrimmed observations (psd1treat, psd2treat,
# psd1control, psd2control) and the sequence indicators of the untrimmed
# observations (idx_treat_without_trimmed, idx_control_without_trimmed).
# If hddyntreat signals a failed fit (returns NA), every element is NA.
dyntreatDML <- function(y2, d1, d2, x0, x1, s = NULL, d1treat = 1, d2treat = 1,
                        d1control = 0, d2control = 0, trim = 0.01,
                        MLmethod = "lasso", fewsplits = FALSE, repl = NULL) {
  # a scalar is a treatment value; a longer vector is a ready-made indicator
  to_indicator <- function(d, value) if (length(value) == 1) 1 * (d == value) else value
  # Per-observation score whose mean estimates the (weighted) mean potential
  # outcome. Columns of the hddyntreat matrix: 1 g(X0), 2 d1, 3 d2, 4 y2,
  # 5 outcome model, 6 p1, 7 p2, 8 nested outcome model, 9 s, 10 trimmed.
  seq_score <- function(sc) {
    (sc[, 1] * sc[, 2] * sc[, 3] * (sc[, 4] - sc[, 5]) / (sc[, 6] * sc[, 7]) +
       sc[, 1] * sc[, 2] * (sc[, 5] - sc[, 8]) / sc[, 6] +
       sc[, 9] * sc[, 8]) / mean(sc[, 9])
  }
  failed_result <- list(
    effect = NA, se = NA, pval = NA, ntrimmed = NA, meantreat = NA, meancontrol = NA,
    psd1treat = NA, psd2treat = NA, psd1control = NA, psd2control = NA,
    idx_treat_without_trimmed = NA, idx_control_without_trimmed = NA
  )

  d1tre <- to_indicator(d1, d1treat)
  d2tre <- to_indicator(d2, d2treat)
  d1con <- to_indicator(d1, d1control)
  d2con <- to_indicator(d2, d2control)

  scorestreat <- hddyntreat(y2 = y2, d1 = d1tre, d2 = d2tre, x0 = x0, x1 = x1, s = s,
                            trim = trim, MLmethod = MLmethod, fewsplits = fewsplits,
                            repl = repl, scoreid = "treat")
  if (!is.matrix(scorestreat)) return(failed_result)  # a nuisance fit failed
  scorescontrol <- hddyntreat(y2 = y2, d1 = d1con, d2 = d2con, x0 = x0, x1 = x1, s = s,
                              trim = trim, MLmethod = MLmethod, fewsplits = fewsplits,
                              repl = repl, scoreid = "control")
  if (!is.matrix(scorescontrol)) return(failed_result)

  # drop observations trimmed in either the treatment or the control score
  trimmed <- 1 * (scorescontrol[, 10] + scorestreat[, 10] > 0)
  keep <- trimmed == 0
  scorestreat <- scorestreat[keep, , drop = FALSE]
  scorescontrol <- scorescontrol[keep, , drop = FALSE]

  tscores <- seq_score(scorestreat)
  cscores <- seq_score(scorescontrol)
  meantreat <- mean(tscores)
  meancontrol <- mean(cscores)
  effect <- meantreat - meancontrol
  se <- sqrt(mean((tscores - cscores - effect)^2) / length(tscores))
  pval <- 2 * pnorm(-abs(effect / se))

  list(effect = effect, se = se, pval = pval, ntrimmed = sum(trimmed),
       meantreat = meantreat, meancontrol = meancontrol,
       psd1treat = scorestreat[, 6], psd2treat = scorestreat[, 7],
       psd1control = scorescontrol[, 6], psd2control = scorescontrol[, 7],
       idx_treat_without_trimmed = d1tre[keep] & d2tre[keep],
       idx_control_without_trimmed = d1con[keep] & d2con[keep])
}

# scores for dynamic treatment effect estimation
# Cross-fitted (3 folds) nuisance estimates for one treatment sequence.
#   y2         outcome in period 2
#   d1, d2     0/1 indicators of the first- and second-period treatment of
#              the sequence under consideration
#   x0, x1     covariates at baseline and in period 1
#   s          optional weighting variable; NULL for unweighted estimation
#   trim       an observation is flagged as trimmed if p1 * p2 < trim
#   MLmethod   learner passed to MLfunct ("lasso" or "randomforest")
#   fewsplits  if 1/TRUE, both outcome models are fitted on the whole training
#              part instead of on two disjoint folds of it
#   repl, scoreid  accepted for compatibility with dyntreatDML; not used here
#   seed       seed of the random fold assignment; the caller's RNG state is
#              restored when the function exits
# Returns a matrix with one row per observation (original order) and columns
#   1 prediction of the weighting function g(X0) (1 if s is NULL)
#   2 d1, 3 d2, 4 y2
#   5 prediction of E[Y2 | D1=1, D2=1, X0, X1]
#   6 prediction of P(D1=1 | X0)
#   7 prediction of P(D2=1 | D1, X0, X1)
#   8 prediction of E[column-5 model | D1=1, X0]
#   9 s (1 if s is NULL)
#   10 trimming flag (1 if p1 * p2 < trim)
# or NA if any of the machine-learning fits failed.
hddyntreat <- function(y2, d1, d2, x0, x1, s = NULL, trim = 0.05,
                       MLmethod = "lasso", fewsplits = FALSE, repl = NULL,
                       scoreid = NULL, seed = 1) {
  # MLfunct signals a failed fit by returning a scalar NaN instead of a model;
  # a successful fit is a list, so a plain is.na() check would not be scalar
  fit_failed <- function(fit) is.atomic(fit) && length(fit) == 1L && is.na(fit)

  ybin <- 1 * (length(unique(y2)) == 2 & min(y2) == 0 & max(y2) == 1)  # binary outcome?
  x0 <- data.frame(x0)
  x0x1 <- data.frame(x0, x1)
  d1x0x1 <- data.frame(d1, x0x1)

  # cross-fitting: random split into training and test data (three folds)
  n <- length(y2)
  stepsize <- ceiling(n / 3)
  if (2 * stepsize >= n) stop("sample too small for 3-fold cross-fitting", call. = FALSE)
  # keep set.seed() local: restore the caller's RNG state on exit
  had_seed <- exists(".Random.seed", envir = globalenv(), inherits = FALSE)
  old_seed <- if (had_seed) get(".Random.seed", envir = globalenv(), inherits = FALSE)
  on.exit({
    if (had_seed) {
      assign(".Random.seed", old_seed, envir = globalenv())
    } else if (exists(".Random.seed", envir = globalenv(), inherits = FALSE)) {
      rm(".Random.seed", envir = globalenv())
    }
  }, add = TRUE)
  set.seed(seed)
  idx <- sample(n)
  folds <- list(idx[1:stepsize],
                idx[(stepsize + 1):(2 * stepsize)],
                idx[(2 * stepsize + 1):n])

  fold_scores <- vector("list", 3)
  for (i in 1:3) {
    # fold i is the test sample; the two other folds are the training data
    tesample <- folds[[i]]
    trsample1 <- folds[[i %% 3 + 1]]
    trsample2 <- folds[[(i + 1) %% 3 + 1]]
    trsample <- c(trsample1, trsample2)  # total training sample
    # with fewsplits, both outcome models use the total training sample
    if (fewsplits == 1) {
      trsample1 <- trsample
      trsample2 <- trsample
    }
    # weighting function (only for weighted estimation)
    if (is.null(s)) {
      gte <- rep(1, length(tesample))
      ste <- gte
    } else {
      g <- MLfunct(y = s[trsample], x = x0[trsample, , drop = FALSE], MLmethod = MLmethod, ybin = 1)
      if (fit_failed(g)) return(NA)
      gte <- predict(g, x0[tesample, , drop = FALSE], onlySL = TRUE)$pred  # weighting function in test data
      ste <- s[tesample]
    }
    # nuisance parameters (fitted in this order; do not reorder)
    p1 <- MLfunct(y = d1[trsample], x = x0[trsample, , drop = FALSE], MLmethod = MLmethod, ybin = 1)
    if (fit_failed(p1)) return(NA)
    p1te <- predict(p1, x0[tesample, , drop = FALSE], onlySL = TRUE)$pred  # ps1 in test data
    p2 <- MLfunct(y = d2[trsample], x = d1x0x1[trsample, , drop = FALSE], MLmethod = MLmethod, ybin = 1)
    if (fit_failed(p2)) return(NA)
    p2te <- predict(p2, d1x0x1[tesample, , drop = FALSE], onlySL = TRUE)$pred  # ps2 in test data
    y2d1d2 <- MLfunct(y = y2[trsample1], x = x0x1[trsample1, , drop = FALSE],
                      d1 = d1[trsample1], d2 = d2[trsample1], MLmethod = MLmethod, ybin = ybin)
    if (fit_failed(y2d1d2)) return(NA)
    y2d1d2te <- predict(y2d1d2, x0x1[tesample, , drop = FALSE], onlySL = TRUE)$pred  # E[Y2|D1,D2,X0,X1] in test data
    y2d1d2tr2 <- predict(y2d1d2, x0x1[trsample2, , drop = FALSE], onlySL = TRUE)$pred  # ... in second training data
    y1d1 <- MLfunct(y = y2d1d2tr2, x = x0[trsample2, , drop = FALSE],
                    d1 = d1[trsample2], MLmethod = MLmethod, ybin = 0)
    if (fit_failed(y1d1)) return(NA)
    y1d1te <- predict(y1d1, x0[tesample, , drop = FALSE], onlySL = TRUE)$pred  # nested model in test data
    trimmed <- 1 * ((p1te * p2te) < trim)  # observations not satisfying trimming restriction
    fold_scores[[i]] <- cbind(gte, d1[tesample], d2[tesample], y2[tesample], y2d1d2te,
                              p1te, p2te, y1d1te, ste, trimmed)
  }
  score <- do.call(rbind, fold_scores)
  # rows are stacked in the order of idx; restore the original observation order
  score[order(idx), , drop = FALSE]
}

# simulations
numsim <- 1000  # number of simulations
n <- 2500  # sample size
p0 <- 500  # number of covariates at baseline
s <- NULL # 1 if effect estimation for subpopulation of treated in first period, else NULL
s0 <- p0; p1 <- p0; s1 <- p0  # p1: covariates in period 1; s0, s1: number of non-zero coefficients
effd1 <- 1  # effect of first treatment
effd2 <- 1  # effect of second treatment
het <- 0  # indicator for effect heterogeneity (switched off if zero)
decrease <- 1  # indicator for decreasing importance of coefficients
xcorr <- 1  # indicator for non-zero covariance between regressors
maxdraws <- 2 * numsim  # upper bound on the number of simulated data sets

# Design quantities that do not change across replications (computed once).
# Covariances: 0.5^|i-j| if xcorr != 0, independent standard normals otherwise.
if (xcorr != 0) {
  sigma0 <- 0.5^abs(outer(seq_len(p0), seq_len(p0), "-"))
  sigma1 <- 0.5^abs(outer(seq_len(p1), seq_len(p1), "-"))
} else {
  sigma0 <- diag(p0)
  sigma1 <- diag(p1)
}
# Coefficients: the j-th of the first s0 (s1) coefficients equals 0.4 / j^4.
if (decrease == 0) stop("only decrease != 0 is implemented", call. = FALSE)
beta0 <- rep(0, p0); beta0[seq_len(s0)] <- 0.4 / seq_len(s0)^4
beta1 <- rep(0, p1); beta1[seq_len(s1)] <- 0.4 / seq_len(s1)^4
weighted <- !is.null(s)  # weighted estimation among the treated in period 1 (weights d1)

est <- rep(NA_real_, numsim); true <- est; se <- est
counter <- 0  # number of simulated data sets so far; also the seed of the current one
i <- 1  # index of the next successful replication
# start the loops for the Monte Carlo replications; a replication whose
# estimate is not finite is redrawn with the next seed
while (i <= numsim) {
  if (counter >= maxdraws) stop("too many failed replications", call. = FALSE)
  cat('Monte Carlo replication: ', i, '\n')
  counter <- counter + 1
  set.seed(counter)
  # random draws below keep the original order: x0, d1 noise, x1, d2 noise, u
  x0 <- rmvnorm(n, rep(0, p0), sigma0)  # covariate matrix at baseline
  d1 <- (x0 %*% beta0 + rnorm(n) > 0) * 1  # equation of first treatment in period 1
  x1 <- rmvnorm(n, rep(0, p1), sigma1)  # covariate matrix of period 1
  xb <- x0 %*% beta0 + x1 %*% beta1  # covariate index shared by d2 and y2
  d2 <- (xb + 0.3 * d1 + rnorm(n) > 0) * 1  # equation of second treatment in period 2
  u <- rnorm(n)
  y2 <- xb + effd1 * d1 + het * x0[, 1] * d1 + effd2 * d2 + u  # outcome equation in period 2
  y2treat <- xb + effd1 + het * x0[, 1] + effd2 + u  # outcome with d1 = d2 = 1
  y2control <- xb + u  # outcome with d1 = d2 = 0
  if (weighted) {
    y2treat <- y2treat[d1 == 1]
    y2control <- y2control[d1 == 1]
  }
  temp <- dyntreatDML(y2, d1, d2, x0, x1, s = if (weighted) d1 else NULL,
                      d1treat = 1, d2treat = 1, d1control = 0, d2control = 0,
                      trim = 0.01, MLmethod = "lasso", fewsplits = FALSE, repl = i)
  if (length(temp$effect) == 1 && is.finite(temp$effect)) {
    est[i] <- temp$effect
    se[i] <- temp$se
    true[i] <- mean(y2treat) - mean(y2control)
    i <- i + 1
  } else {
    cat('Estimation failed; drawing a new sample\n')
  }
}

# results
# Summary statistics over the Monte Carlo replications.
true_effect <- mean(true)  # average true effect
absolute_bias <- abs(mean(est - true))  # absolute value of the mean estimation error
standard_deviation <- sd(est)  # standard deviation of the estimates
average_se <- mean(se)  # mean of the estimated standard errors
rmse <- sqrt(mean((est - true)^2))  # root mean squared error
# share of replications whose interval est +/- 1.959964 * se contains the true effect
coverage <- mean(true <= est + 1.959964 * se & true >= est - 1.959964 * se)
